{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Validation of GPT-2 trained on TinyStories\n",
    "\n",
    "Load the model and tokenizer.\n",
    "The model checkpoint `gpt2-tinystories-final` is produced by the training script `gpt2_tinystories.py`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model loaded on:  cuda\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from transformers import (\n",
    "    GPT2LMHeadModel,\n",
    "    GPT2TokenizerFast,\n",
    ")\n",
    "\n",
    "# Load tokenizer\n",
    "tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')\n",
    "tokenizer.pad_token = tokenizer.eos_token\n",
    "\n",
    "# Load model and move to GPU\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "model = GPT2LMHeadModel.from_pretrained('gpt2-tinystories-final')\n",
    "model.to(device)\n",
    "\n",
    "print(\"Model loaded on: \", device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Generate sample stories given prompts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Testing model generation:\n",
      "\n",
      "Prompt: Once upon a time\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generated text: Once upon a time, there was a little girl named Lily. She loved to play outside in the park with her friends. One day, Lily and her friend Jack were playing hide-and-seek. Jack was hiding behind a big tree when he heard a loud noise. \n",
      "\n",
      "\"What was that?\" asked Jack.\n",
      "   \"I don't know,\" replied Lily, \"but it sounds like someone's house is on fire.\" \n",
      "\n",
      " Jack and Lily quickly ran to the nearby\n",
      "\n",
      "Prompt: The little dog\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generated text: The little dog was running around the yard and he saw a cable lying on the ground. He picked it up and started to play with it. Suddenly, he heard a loud noise. It was his mom calling him for lunch.\n",
      "\n",
      "He ran to the kitchen to get something to eat. But when he got there, his dog's mom had already gone. She had left the cable on her desk. \n",
      "The dog didn't want to leave the cord alone. So he started running away\n",
      "\n",
      "Prompt: In the garden\n",
      "Generated text: In the garden, there was a little girl named Lucy. She had a toy garden with lots of flowers and vegetables. Every day, she would take her toy shovel and dig in the dirt.\n",
      "\n",
      "One day Lucy found something very special. It was an old box with a mysterious lid. Lucy opened the lid and saw lots and lots inside. Inside were lots, but not all. There were many different kinds of vegetables, too. \n",
      "   Lucy picked up a few of the vegetables\n"
     ]
    }
   ],
   "source": [
    "# Generate a continuation for a prompt using top-k / nucleus sampling\n",
    "def generate_text(prompt, max_length=100):\n",
    "    # max_length counts the prompt tokens plus the generated tokens\n",
    "    inputs = tokenizer(prompt, return_tensors=\"pt\").to(model.device)\n",
    "    outputs = model.generate(\n",
    "        inputs.input_ids,\n",
    "        attention_mask=inputs.attention_mask,\n",
    "        max_length=max_length,\n",
    "        num_return_sequences=1,\n",
    "        no_repeat_ngram_size=2,\n",
    "        do_sample=True,\n",
    "        top_k=50,\n",
    "        top_p=0.95,\n",
    "        temperature=0.7,\n",
    "        # Set explicitly so generate() does not have to fall back to eos_token_id with a warning\n",
    "        pad_token_id=tokenizer.eos_token_id\n",
    "    )\n",
    "    return tokenizer.decode(outputs[0], skip_special_tokens=True)\n",
    "\n",
    "# Test the model\n",
    "test_prompts = [\n",
    "    \"Once upon a time\",\n",
    "    \"The little dog\",\n",
    "    \"In the garden\"\n",
    "]\n",
    "\n",
    "print(\"\\nTesting model generation:\")\n",
    "for prompt in test_prompts:\n",
    "    print(f\"\\nPrompt: {prompt}\")\n",
    "    print(f\"Generated text: {generate_text(prompt)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Use an interactive demo\n",
    "\n",
    "This cell prompts the user for the beginning of a story and then generates a continuation.\n",
    "Type `exit` to quit the demo."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generated text: In the middle of the road was a beautiful beach. The sand was soft and the sun shone down on it.\n",
      "\n",
      "Two children, Max and Lily, were playing in the sand. They ran and laughed as they built castles, dug tunnels and even built a sandcastle. \n",
      "Max had a bucket of water and he was pouring it on the ground. Lily laughed and said \"It's so funny!\"\n",
      "  Max was playing with his bucket and having a great time. He was having\n"
     ]
    }
   ],
   "source": [
    "# Interactive loop: read a prompt, print the generated story, and repeat until the user types `exit`\n",
    "def interactive_prompt():\n",
    "    while True:\n",
    "        prompt = input(\"\\nEnter a prompt: \")\n",
    "        if prompt.strip().lower() == \"exit\":\n",
    "            break\n",
    "        print(f\"Generated text: {generate_text(prompt)}\")\n",
    "        \n",
    "interactive_prompt()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Evaluate the model on a validation set\n",
    "\n",
    "Using perplexity as the evaluation metric\n",
    "\n",
    "```python"
   ]
  },
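  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, a sketch of the standard definition of perplexity. The notation is an assumption introduced here and does not come from the training script: $N$ is the number of scored (non-padding) tokens, $x_i$ is the $i$-th token, and $p_\\theta$ is the model's predicted next-token distribution.\n",
    "\n",
    "$$\n",
    "\\mathrm{PPL} = \\exp\\left(-\\frac{1}{N}\\sum_{i=1}^{N}\\log p_\\theta(x_i \\mid x_{1:i-1})\\right)\n",
    "$$\n",
    "\n",
    "The exponential is applied once, after averaging the per-token loss over the whole dataset. Exponentiating each batch loss and then averaging the results generally gives a larger value (Jensen's inequality), so the two procedures are not interchangeable."
   ]
  },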
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f4de6198cb4b43758e72660197a99bc2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "IntProgress(value=0, description='Evaluating', max=2749)"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average Loss: 10.3397\n",
      "Perplexity: 86240.4700\n"
     ]
    }
   ],
   "source": [
    "import math\n",
    "from transformers import DataCollatorForLanguageModeling\n",
    "from torch.utils.data import DataLoader\n",
    "from datasets import load_dataset\n",
    "from IPython.display import display\n",
    "import ipywidgets as widgets\n",
    "\n",
    "\n",
    "dataset = load_dataset(\"roneneldan/TinyStories\")\n",
    "\n",
    "def tokenize_function(examples):\n",
    "    \"\"\"Tokenize dataset examples.\"\"\"\n",
    "    return tokenizer(\n",
    "        examples[\"text\"],\n",
    "        truncation=True,\n",
    "        max_length=512,\n",
    "        padding=\"max_length\"\n",
    "    )\n",
    "\n",
    "# Tokenize the validation split\n",
    "tokenized_test = dataset[\"validation\"].map(\n",
    "    tokenize_function,\n",
    "    batched=True,\n",
    "    remove_columns=dataset[\"validation\"].column_names\n",
    ")\n",
    "\n",
    "# Create data collator (mlm=False: causal LM labels, with padding positions set to -100)\n",
    "data_collator = DataCollatorForLanguageModeling(\n",
    "    tokenizer=tokenizer,\n",
    "    mlm=False\n",
    ")\n",
    "\n",
    "# Create dataloader\n",
    "test_dataloader = DataLoader(\n",
    "    tokenized_test,\n",
    "    batch_size=8,\n",
    "    collate_fn=data_collator\n",
    ")\n",
    "\n",
    "def compute_metrics(eval_pred):\n",
    "    \"\"\"Compute perplexity metric.\"\"\"\n",
    "    logits, labels = eval_pred\n",
    "    shift_logits = logits[..., :-1, :].contiguous()\n",
    "    shift_labels = labels[..., 1:].contiguous()\n",
    "    \n",
    "    loss = torch.nn.functional.cross_entropy(\n",
    "        torch.from_numpy(shift_logits.reshape(-1, shift_logits.shape[-1])),\n",
    "        torch.from_numpy(shift_labels.reshape(-1))\n",
    "    )\n",
    "    \n",
    "    try:\n",
    "        perplexity = math.exp(loss.item())\n",
    "    except OverflowError:\n",
    "        perplexity = float(\"inf\")\n",
    "    \n",
    "    return {\"perplexity\": perplexity}\n",
    "\n",
    "###################\n",
    "\n",
    "device = model.device\n",
    "model.eval()\n",
    "\n",
    "# Evaluation loop: accumulate the summed token-level loss and the number of scored tokens\n",
    "total_nll = 0.0\n",
    "total_tokens = 0\n",
    "progress_bar = widgets.IntProgress(min=0, max=len(test_dataloader), description='Evaluating')\n",
    "display(progress_bar)\n",
    "\n",
    "with torch.no_grad():\n",
    "    for batch in test_dataloader:\n",
    "        input_ids = batch['input_ids'].to(device)\n",
    "        attention_mask = batch['attention_mask'].to(device)\n",
    "        # Use the collator's labels so that padding positions (-100) are excluded from the loss\n",
    "        labels = batch['labels'].to(device)\n",
    "\n",
    "        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)\n",
    "\n",
    "        # outputs.loss is the mean over scored tokens; the model shifts labels by one internally\n",
    "        n_tokens = (labels[:, 1:] != -100).sum().item()\n",
    "        if n_tokens > 0:\n",
    "            total_nll += outputs.loss.item() * n_tokens\n",
    "            total_tokens += n_tokens\n",
    "\n",
    "        progress_bar.value += 1\n",
    "\n",
    "# Perplexity is the exponential of the average per-token loss\n",
    "avg_loss = total_nll / total_tokens\n",
    "perplexity = math.exp(avg_loss)\n",
    "\n",
    "print(f\"Average Loss: {avg_loss:.4f}\")\n",
    "print(f\"Perplexity: {perplexity:.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Text: Once upon a time\n",
      "Perplexity: 1.01\n",
      "\n",
      "Text: The little dog\n",
      "Perplexity: 38.56\n",
      "\n",
      "Text: In the garden\n",
      "Perplexity: 5.08\n"
     ]
    }
   ],
   "source": [
    "def calculate_perplexity(text, model, tokenizer, device=\"cuda\"):\n",
    "    # Encode text\n",
    "    encodings = tokenizer(\n",
    "        text,\n",
    "        return_tensors=\"pt\",\n",
    "        truncation=True,\n",
    "        max_length=512\n",
    "    )\n",
    "    \n",
    "    # Move to device\n",
    "    input_ids = encodings[\"input_ids\"].to(device)\n",
    "    \n",
    "    # Calculate perplexity\n",
    "    with torch.no_grad():\n",
    "        outputs = model(input_ids, labels=input_ids)\n",
    "        loss = outputs.loss\n",
    "    \n",
    "    return math.exp(loss.item())\n",
    "\n",
    "# Test the model\n",
    "texts = [\n",
    "    \"Once upon a time\",\n",
    "    \"The little dog\",\n",
    "    \"In the garden\"\n",
    "]\n",
    "\n",
    "# Calculate perplexity for each text\n",
    "for text in texts:\n",
    "    perplexity = calculate_perplexity(text, model, tokenizer, device)\n",
    "    print(f\"\\nText: {text}\")\n",
    "    print(f\"Perplexity: {perplexity:.2f}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "app",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
